Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support gather of pytorch #1622

Merged
merged 16 commits into from
May 6, 2022
Merged

Support gather of pytorch #1622

merged 16 commits into from
May 6, 2022

Conversation

KexinFeng
Copy link
Contributor

@KexinFeng KexinFeng commented May 3, 2022

Description

This originiates from issue #248. This is regarding support of advanced indexing.

As part of this support, this PR implements gather. It's the same as the gather in
pytorh and numpy :

out_{ijk} = arr_{idx_{ijk}, j, k} if axis=0
              or arr_{i, idx_{ijk}, k} if axis=1
              or arr_{i, j, idx_{ijk}}  if axis=2

But gather is used to index across axis, including the preceding axis. This conflicts with the linear storage of List<NDIndexElement> inside NDIndex. So gather is treated parallel to NDIndex when fed into get.
It currently supports pytorch engine only. MXNet and TensorFlow have different defination of gather.

Similar functions in numpy / pytorch / mxnet can also be implemented: e.g. torch.take, and advanced indexing like arr[:, [0 ,2]] and arr[idx_r, idx_c]

@KexinFeng KexinFeng marked this pull request as ready for review May 4, 2022 16:41
@KexinFeng KexinFeng marked this pull request as draft May 4, 2022 16:43
@KexinFeng KexinFeng marked this pull request as ready for review May 4, 2022 16:56
@codecov-commenter
Copy link

codecov-commenter commented May 4, 2022

Codecov Report

Merging #1622 (822c27b) into master (bb5073f) will decrease coverage by 1.20%.
The diff coverage is 62.15%.

@@             Coverage Diff              @@
##             master    #1622      +/-   ##
============================================
- Coverage     72.08%   70.87%   -1.21%     
- Complexity     5126     5433     +307     
============================================
  Files           473      507      +34     
  Lines         21970    23767    +1797     
  Branches       2351     2588     +237     
============================================
+ Hits          15838    16846    +1008     
- Misses         4925     5631     +706     
- Partials       1207     1290      +83     
Impacted Files Coverage Δ
api/src/main/java/ai/djl/modality/cv/Image.java 69.23% <ø> (-4.11%) ⬇️
...i/djl/modality/cv/translator/BigGANTranslator.java 21.42% <ø> (-5.24%) ⬇️
...odality/cv/translator/BigGANTranslatorFactory.java 33.33% <0.00%> (+8.33%) ⬆️
...nslator/InstanceSegmentationTranslatorFactory.java 14.28% <0.00%> (-3.90%) ⬇️
.../modality/cv/translator/YoloTranslatorFactory.java 8.33% <0.00%> (-1.67%) ⬇️
...i/djl/modality/cv/translator/YoloV5Translator.java 5.69% <0.00%> (ø)
...odality/cv/translator/YoloV5TranslatorFactory.java 8.33% <0.00%> (-1.67%) ⬇️
...pi/src/main/java/ai/djl/ndarray/BytesSupplier.java 54.54% <0.00%> (-12.13%) ⬇️
api/src/main/java/ai/djl/ndarray/NDArray.java 78.40% <ø> (+1.13%) ⬆️
...ain/java/ai/djl/ndarray/index/dim/NDIndexPick.java 100.00% <ø> (ø)
... and 241 more

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 05a685e...822c27b. Read the comment docs.

@@ -0,0 +1,51 @@
/*
* Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
* Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
* Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Got it. Thanks!

private int axis;

/**
* Constructs a new {@link NDIndexFullGather}.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
* Constructs a new {@link NDIndexFullGather}.
* Constructs a new {@code NDIndexFullGather} instance.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Got it. Thanks!

@KexinFeng KexinFeng force-pushed the gather_dev branch 2 times, most recently from e25a16f to d1fc3a9 Compare May 6, 2022 02:22
// In the dependency, changing runtimeOnly to api however will remedy the problem.
// TODO: remove this when gradle problem is fixed.
TestRequirements.notWindows();
Engine engine = Engine.getEngine("PyTorch");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The integration module is supposed to be for "engine agnostic tests". So, you shouldn't need to specify "PyTorch" here. For the other engines, it will throw a UnsupportedOperationException instead of running successfully. But, our test runner is designed to treat UnsupportedOperationException as a special test result, not a failure. It's similar to how a testng SkipException counts as a skipped test result rather than a test failure.

In the future, we should be able to implement gather for the other engines and then it will just run the test because it is already there.

@@ -52,6 +53,22 @@ public void testPick() {
}
}

@Test
public void testGather() {
// Currently in windows gradle cannot find all the engines.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this comment still valid?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably. It seems like something is wrong with runtimeOnly dependencies on windows gradle.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes after updating gradle, on GitHub the test stills fails for the same reason

@zachgk zachgk merged commit f673b0b into master May 6, 2022
@zachgk zachgk deleted the gather_dev branch May 6, 2022 20:01
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants